/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
 */
#include "roi_align_rotated_v2_tiling.h"
#include "register/op_def_registry.h"
#include "tiling/platform/platform_ascendc.h"
#include <cmath>

using namespace std;
namespace optiling {
// Tile count written into the tiling data's tileNum field by the tiling function.
constexpr uint32_t TILE_NUM = 8U;
/**
 * @brief Host-side tiling for the RoiAlignRotatedV2 operator.
 *
 * Reads the shapes of input 0 (feature map) and input 1 (rois), the six
 * operator attributes and the platform description, derives how the rois are
 * split across vector cores, and serialises the result into the raw tiling
 * buffer of @p context. It also sets the block dimension, tiling key 1 and a
 * zero-sized workspace.
 *
 * @param context Tiling context supplied by the framework.
 * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when a required
 *         pointer (context, shapes, attributes, platform info, tiling buffer,
 *         workspace array) is null, when the roi count is zero, or when the
 *         platform reports zero vector cores.
 */
static ge::graphStatus TilingForRoiAlignRotatedV2(gert::TilingContext* context)
{
    // Rounds `value` up to the next multiple of `unit`.
    const auto align_up = [](uint32_t value, uint32_t unit) { return (value + unit - 1U) / unit * unit; };
    // Channel and roi counts are padded to multiples of 8 elements
    // (8 * sizeof(float) == 32 bytes, the same 32-byte rounding applied to the input buffer below).
    constexpr uint32_t kAlignUnit = 8U;
    constexpr uint64_t kBufferAlignBytes = 32U;

    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }

    RoiAlignRotatedV2TilingData tiling;

    const auto* input_storage = context->GetInputShape(0);
    const auto* rois_storage = context->GetInputShape(1);
    if (input_storage == nullptr || rois_storage == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const auto& input_shape = input_storage->GetStorageShape();
    const auto& rois_shape = rois_storage->GetStorageShape();

    // The feature map is indexed as [batch, height, width, channels].
    const uint32_t batch_size = static_cast<uint32_t>(input_shape.GetDim(0));
    const uint32_t input_h = static_cast<uint32_t>(input_shape.GetDim(1));
    const uint32_t input_w = static_cast<uint32_t>(input_shape.GetDim(2));
    const uint32_t channels = static_cast<uint32_t>(input_shape.GetDim(3));
    const uint32_t channels_aligned = align_up(channels, kAlignUnit);

    const auto* attrsPtr = context->GetAttrs();
    if (attrsPtr == nullptr) {
        return ge::GRAPH_FAILED;
    }

    // Attribute indices follow the declaration order in the op definition.
    // NOTE(review): the integer attributes are declared with .Int() but read through
    // uint32_t pointers; confirm the framework's storage width for Int attributes.
    const float* spatial_scale_ptr = attrsPtr->GetAttrPointer<float>(0);
    const uint32_t* sampling_ratio_ptr = attrsPtr->GetAttrPointer<uint32_t>(1);
    const uint32_t* pooled_height_ptr = attrsPtr->GetAttrPointer<uint32_t>(2);
    const uint32_t* pooled_width_ptr = attrsPtr->GetAttrPointer<uint32_t>(3);
    const bool* aligned_ptr = attrsPtr->GetAttrPointer<bool>(4);
    const bool* clockwise_ptr = attrsPtr->GetAttrPointer<bool>(5);
    if (spatial_scale_ptr == nullptr || sampling_ratio_ptr == nullptr || pooled_height_ptr == nullptr ||
        pooled_width_ptr == nullptr || aligned_ptr == nullptr || clockwise_ptr == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const float spatial_scale = *spatial_scale_ptr;
    const int32_t sampling_ratio = static_cast<int32_t>(*sampling_ratio_ptr);
    const int32_t pooled_height = static_cast<int32_t>(*pooled_height_ptr);
    const int32_t pooled_width = static_cast<int32_t>(*pooled_width_ptr);
    const bool aligned = *aligned_ptr;
    const bool clockwise = *clockwise_ptr;

    // NOTE(review): the roi count is taken from dim 1 of the rois tensor here, while
    // InferShape in this file takes it from dim 0 - confirm which layout callers pass.
    const uint32_t rois_num = static_cast<uint32_t>(rois_shape.GetDim(1));
    if (rois_num == 0) {
        return ge::GRAPH_FAILED;
    }

    auto platform = context->GetPlatformInfo();
    if (platform == nullptr) {
        return ge::GRAPH_FAILED;
    }

    auto platform_info = platform_ascendc::PlatformAscendC(platform);
    uint64_t ub_total_size = 0; // initialised so a failed query cannot leak an indeterminate value
    platform_info.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ub_total_size);
    uint32_t block_dim = platform_info.GetCoreNumAiv();
    if (block_dim == 0) {
        return ge::GRAPH_FAILED;
    }

    const uint32_t rois_num_aligned = align_up(rois_num, kAlignUnit);
    // Number of padding rois that have to be discarded once the computation has finished.
    const uint32_t tail_num = rois_num_aligned - rois_num;

    // Split the padded rois over the cores in groups of 8: "Score" cores take
    // rois_num_per_Score rois each and "Lcore" cores take 8 more, so that
    // Score_num * rois_num_per_Score + Lcore_num * rois_num_per_Lcore == rois_num_aligned.
    const uint32_t rois_num_per_Score = (rois_num_aligned / block_dim / kAlignUnit) * kAlignUnit;
    const uint32_t rois_num_per_Lcore = rois_num_per_Score + kAlignUnit;
    const uint32_t Score_num = (block_dim * (kAlignUnit + rois_num_per_Score) - rois_num_aligned) / kAlignUnit;
    const uint32_t Lcore_num = block_dim - Score_num;

    // When the "Score" cores would receive no rois at all, launch only the "Lcore" cores.
    // (rois_num_per_Lcore is always rois_num_per_Score + 8, so no matching check is needed for it.)
    if (rois_num_per_Score == 0) {
        block_dim = Lcore_num;
    }

    // Bytes for one padded channel vector, rounded up to 32 bytes using exact integer arithmetic.
    const uint64_t channel_bytes = static_cast<uint64_t>(channels_aligned) * sizeof(float);
    const uint32_t input_buffer_size =
        static_cast<uint32_t>((channel_bytes + kBufferAlignBytes - 1U) / kBufferAlignBytes * kBufferAlignBytes);

    tiling.set_blockDim(block_dim);
    tiling.set_ub_total_size(ub_total_size);
    tiling.set_tileNum(TILE_NUM);
    tiling.set_batch_size(batch_size);
    tiling.set_channels(channels);
    tiling.set_channels_aligned(channels_aligned);
    tiling.set_input_h(input_h);
    tiling.set_input_w(input_w);
    tiling.set_rois_num_aligned(rois_num_aligned);
    tiling.set_tail_num(tail_num);
    tiling.set_spatial_scale(spatial_scale);
    tiling.set_sampling_ratio(sampling_ratio);
    tiling.set_pooled_height(pooled_height);
    tiling.set_pooled_width(pooled_width);
    tiling.set_aligned(aligned);
    tiling.set_clockwise(clockwise);
    tiling.set_rois_num_per_Lcore(rois_num_per_Lcore);
    tiling.set_rois_num_per_Score(rois_num_per_Score);
    tiling.set_Lcore_num(Lcore_num);
    tiling.set_Score_num(Score_num);
    tiling.set_input_buffer_size(input_buffer_size);

    auto* raw_tiling = context->GetRawTilingData();
    if (raw_tiling == nullptr) {
        return ge::GRAPH_FAILED;
    }
    tiling.SaveToBuffer(raw_tiling->GetData(), raw_tiling->GetCapacity());
    raw_tiling->SetDataSize(tiling.GetDataSize());

    context->SetBlockDim(block_dim);
    context->SetTilingKey(1);

    // A single workspace entry is requested and set to zero bytes.
    size_t* currentWorkspace = context->GetWorkspaceSizes(1);
    if (currentWorkspace == nullptr) {
        return ge::GRAPH_FAILED;
    }
    currentWorkspace[0] = 0;

    return ge::GRAPH_SUCCESS;
}
}

namespace ge {
/**
 * @brief Shape inference for RoiAlignRotatedV2.
 *
 * Writes output 0 as {rois dim 0, pooled_height, pooled_width, input dim 3},
 * where pooled_height / pooled_width are attributes 2 and 3.
 *
 * @param context Shape-inference context supplied by the framework.
 * @return GRAPH_SUCCESS on success; GRAPH_FAILED when the context, either
 *         input shape, the output shape, the attribute table or one of the
 *         two pooled-size attributes is unavailable.
 */
static ge::graphStatus InferShape(gert::InferShapeContext* context)
{
    if (context == nullptr) {
        return GRAPH_FAILED;
    }

    const gert::Shape* input_shape = context->GetInputShape(0);
    const gert::Shape* rois_shape = context->GetInputShape(1);
    gert::Shape* output_shape = context->GetOutputShape(0);
    if (input_shape == nullptr || rois_shape == nullptr || output_shape == nullptr) {
        return GRAPH_FAILED;
    }

    auto attrsPtr = context->GetAttrs();
    if (attrsPtr == nullptr) {
        return GRAPH_FAILED;
    }
    // NOTE(review): these attributes are declared with .Int() but read through
    // uint32_t pointers; confirm the framework's storage width for Int attributes.
    const uint32_t* pooled_height_ptr = attrsPtr->GetAttrPointer<uint32_t>(2);
    const uint32_t* pooled_width_ptr = attrsPtr->GetAttrPointer<uint32_t>(3);
    if (pooled_height_ptr == nullptr || pooled_width_ptr == nullptr) {
        return GRAPH_FAILED;
    }

    const int64_t rois_num = rois_shape->GetDim(0);
    const int64_t channels = input_shape->GetDim(3);
    const int64_t pooled_height = static_cast<int64_t>(*pooled_height_ptr);
    const int64_t pooled_width = static_cast<int64_t>(*pooled_width_ptr);

    *output_shape = {rois_num, pooled_height, pooled_width, channels};

    return GRAPH_SUCCESS;
}
/**
 * @brief Data-type inference for RoiAlignRotatedV2: output 0 takes the data type of input 0.
 *
 * @param context Data-type inference context supplied by the framework.
 * @return Always GRAPH_SUCCESS.
 */
static ge::graphStatus InferDataTypeRoiAlignRotatedV2(gert::InferDataTypeContext* context)
{
    context->SetOutputDataType(0, context->GetInputDataType(0));
    return GRAPH_SUCCESS;
}
}

namespace ops {
/**
 * Operator definition for RoiAlignRotatedV2.
 *
 * Declares two required float/ND inputs ("input", "rois"), one required
 * float/ND output ("output") and six required attributes, then attaches the
 * shape/data-type inference functions and the AI Core tiling function.
 */
class RoiAlignRotatedV2 : public OpDef {
public:
    explicit RoiAlignRotatedV2(const char* name) : OpDef(name)
    {
        // Tensor parameters.
        this->Input("input").ParamType(REQUIRED).DataType({ge::DT_FLOAT}).Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND});
        this->Input("rois").ParamType(REQUIRED).DataType({ge::DT_FLOAT}).Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND});
        this->Output("output").ParamType(REQUIRED).DataType({ge::DT_FLOAT}).Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND});

        // Attributes; the declaration order gives the indices 0..5 used by the tiling and infer-shape code.
        this->Attr("spatial_scale")
            .AttrType(REQUIRED)
            .Float();
        this->Attr("sampling_ratio")
            .AttrType(REQUIRED)
            .Int();
        this->Attr("pooled_height")
            .AttrType(REQUIRED)
            .Int();
        this->Attr("pooled_width")
            .AttrType(REQUIRED)
            .Int();
        this->Attr("aligned")
            .AttrType(REQUIRED)
            .Bool();
        this->Attr("clockwise")
            .AttrType(REQUIRED)
            .Bool();

        // Inference hooks.
        this->SetInferShape(ge::InferShape).SetInferDataType(ge::InferDataTypeRoiAlignRotatedV2);

        // AI Core: tiling entry point followed by the configured SoC names.
        this->AICore().SetTiling(optiling::TilingForRoiAlignRotatedV2);
        const char* const socNames[] = {"ascend910b", "ascend910_93"};
        for (const char* soc : socNames) {
            this->AICore().AddConfig(soc);
        }
    }
};

// Hand the RoiAlignRotatedV2 definition to the OP_ADD macro
// (presumably registers the op with the registry from register/op_def_registry.h - confirm).
OP_ADD(RoiAlignRotatedV2);
}